{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ONNX Script 生成 FunctionProto\n",
    "\n",
    "以下示例展示了我们如何在 `onnxscript` 中将 `Selu` 定义为函数。\n",
    "\n",
    "首先，导入用于定义函数的 ONNX `opset`。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "from onnxscript import opset15 as op\n",
    "from onnxscript import script"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义 Selu："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Selu as an ONNX Script function: X is the tensor input, while the\n",
    "# float-annotated alpha and gamma show up as attributes in the FunctionProto.\n",
    "@script()\n",
    "def Selu(X, alpha: float, gamma: float):\n",
    "    # Bring the attribute values to X's element type before mixing them with X.\n",
    "    alphaX = op.CastLike(alpha, X)\n",
    "    gammaX = op.CastLike(gamma, X)\n",
    "    # Branch selected where X <= 0: gamma * (alpha * exp(X) - alpha).\n",
    "    neg = gammaX * (alphaX * op.Exp(X) - alphaX)\n",
    "    # Branch selected where X > 0: gamma * X.\n",
    "    pos = gammaX * X\n",
    "    # The literal 0 is cast like X so the comparison below uses matching types.\n",
    "    zero = op.CastLike(0, X)\n",
    "    # Element-wise select: neg where 0 >= X, pos elsewhere.\n",
    "    return op.Where(zero >= X, neg, pos)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以像下面这样将 ONNX Script 函数转换为 ONNX 函数（FunctionProto）："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the ONNX FunctionProto from the scripted Selu function defined above.\n",
    "onnx_fun = Selu.to_function_proto()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "打印函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<\n",
      "  domain: \"this\",\n",
      "  opset_import: [\"\" : 15]\n",
      ">\n",
      "Selu <alpha,gamma>(X) => (return_val)\n",
      "{\n",
      "   alpha = Constant <value_float: float = @alpha> ()\n",
      "   alphaX = CastLike (alpha, X)\n",
      "   gamma = Constant <value_float: float = @gamma> ()\n",
      "   gammaX = CastLike (gamma, X)\n",
      "   tmp = Exp (X)\n",
      "   tmp_0 = Mul (alphaX, tmp)\n",
      "   tmp_1 = Sub (tmp_0, alphaX)\n",
      "   neg = Mul (gammaX, tmp_1)\n",
      "   pos = Mul (gammaX, X)\n",
      "   int64_0 = Constant <value: tensor = int64 int64_0 {0}> ()\n",
      "   zero = CastLike (int64_0, X)\n",
      "   tmp_2 = GreaterOrEqual (zero, X)\n",
      "   return_val = Where (tmp_2, neg, pos)\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "# Render the FunctionProto as human-readable text; `onnx` is imported in the first code cell.\n",
    "print(onnx.printer.to_text(onnx_fun))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
